from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import pdb
import oneflow
import argparse

# Command-line configuration, parsed once at import time; main() reads the
# resulting module-level `args`.
parser = argparse.ArgumentParser()
# Checkpoint files for the two teachers (fed to oneflow.load in main()).
parser.add_argument('--car_ckpt', type=str, default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', type=str, default="./ckpt/dog_res50_model")
# Learning rate of the student's Adam optimizer.
parser.add_argument('--lr', default=1e-3, type=float)

args = parser.parse_args()


def main(data_root='./DataSets', batch_size=32, epochs=100):
    """Amalgamate a StanfordCars teacher and a StanfordDogs teacher into one student.

    Two ResNet-50 teachers are built and loaded:
      * a cars teacher with 196 outputs, from ``args.car_ckpt``;
      * a dogs teacher with 120 outputs, from ``args.dog_ckpt``.

    A ResNet-50 student with 196 + 120 outputs is trained with
    ``amalgamation.LayerWiseAmalgamator``. Training uses the concatenation of
    both training splits, and every Conv2d layer is paired with the
    corresponding Conv2d of each teacher.

    After every epoch the student is evaluated on each test split, and an
    ``EvalAndCkpt`` callback is run with the ``ttl_car`` / ``ttl_dog`` prefixes:
      * the first 196 logits are scored against car labels;
      * the remaining 120 logits are scored against dog labels.

    Args:
        data_root: directory that contains the ``StanfordCars/`` and
            ``StanfordDogs/`` dataset folders.
        batch_size: batch size of the training and evaluation loaders.
        epochs: the trainer runs ``epochs * len(train_loader)`` iterations,
            which is also the period of the cosine learning-rate schedule.

    Raises:
        RuntimeError: if a student Conv2d layer is paired with a teacher
            module that is not a Conv2d (i.e. the architectures differ).
    """
    num_car_classes = 196
    num_dog_classes = 120

    # Datasets. NOTE(review): `s=0.1` is passed only to the training splits;
    # presumably a sub-sampling ratio of the kamal dataset classes — confirm.
    car_root = data_root + '/StanfordCars/'
    dog_root = data_root + '/StanfordDogs/'
    car_train_dst = vision.datasets.StanfordCars(car_root, split='train', s=0.1)
    car_val_dst = vision.datasets.StanfordCars(car_root, split='test')
    dog_train_dst = vision.datasets.StanfordDogs(dog_root, split='train', s=0.1)
    dog_val_dst = vision.datasets.StanfordDogs(dog_root, split='test')

    # Teachers and student: the student's head covers both label spaces.
    car_teacher = vision.models.classification.resnet50(num_classes=num_car_classes, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=num_dog_classes, pretrained=False)
    student = vision.models.classification.resnet50(
        num_classes=num_car_classes + num_dog_classes, pretrained=False)

    # Teacher weights.
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    car_train_dst.transform = dog_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = val_transform

    # The student emits car logits first, then dog logits; each metric only
    # scores its own slice against that dataset's labels.
    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :num_car_classes], t)),
    })
    dog_metric = metrics.MetricCompose(metric_dict={
        'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, num_car_classes:], t)),
    })

    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst])
    train_loader = oneflow.utils.data.DataLoader(
        train_dst, batch_size=batch_size, shuffle=True, num_workers=0)
    car_loader = oneflow.utils.data.DataLoader(
        car_val_dst, batch_size=batch_size, shuffle=False, num_workers=0)
    dog_loader = oneflow.utils.data.DataLoader(
        dog_val_dst, batch_size=batch_size, shuffle=False, num_workers=0)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)

    total_iters = len(train_loader) * epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    # Stepped once per iteration (AFTER_STEP callback below), so one cosine
    # period spans the whole run.
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_iters)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
    )

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='ttl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='ttl_dog'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Pair every student Conv2d with the teachers' module at the same position
    # in modules() order; this relies on all three networks sharing the same
    # backbone layout.
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules()):
        if not isinstance(stu_block, oneflow.nn.Conv2d):
            continue
        if not (isinstance(car_block, oneflow.nn.Conv2d) and isinstance(dog_block, oneflow.nn.Conv2d)):
            raise RuntimeError(
                'student/teacher architectures are misaligned: student Conv2d paired with %s and %s'
                % (type(car_block).__name__, type(dog_block).__name__))
        layer_groups.append([stu_block, car_block, dog_block])
        layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels])

    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])
    trainer.run(start_iter=0, max_iter=total_iters)

if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()